PyTorch LightningのTrainerの仕組み
code: python
# put model in train mode
model.train()
torch.set_grad_enabled(True)
losses = []
for batch in train_dataloader:
# calls hooks like this one
on_train_batch_start()
# train step
loss = training_step(batch)
# clear gradients
optimizer.zero_grad()
# backward
loss.backward()
# update parameters
optimizer.step()
losses.append(loss)
LightningModuleのフックの一覧でもある
pl.LightningModuleの__call__はnn.Moduleに定義された__call__
語弊を恐れずに言えば、forwardメソッドを呼び出す